#include <time.h>
#include <stdio.h>
#include <stdint.h>
#include <math.h>

#include "forward.h"
#include "auxiliary.h"

// Compile-time selection of the per-pixel blending routine installed by
// initialize_process_pixel_mode():
//   0 = full-precision float          (process_pixel)
//   1 = 8-bit E4M3 emulation          (process_pixel_bfloat8)
//   2 = 8-bit E5M2 emulation          (process_pixel_e5m2)
//   3 = BFloat16 emulation            (process_pixel_bfloat16)
// Any value other than 1-3 falls back to mode 0.
#define PROCESS_PIXEL_MODE 0

// Storage types for the emulated low-precision formats. Each holds the raw
// bit pattern in an unsigned integer; the pixel routines convert to float,
// do the arithmetic, and convert back.
// NOTE: despite its name, `bfloat8` holds an E4M3 value (1 sign, 4 exponent,
// 3 mantissa bits), not a truncated bfloat16.
typedef uint8_t bfloat8;       // 8-bit float (E4M3)
typedef uint8_t float8_e5m2;   // 8-bit float (E5M2)
typedef uint16_t bfloat16;     // 16-bit float (BFloat16)
// The float <-> quantized conversion helpers are defined further down in this file.

// ---------------------------------------------------------------------
//               Simple get_time() helper
// ---------------------------------------------------------------------
/*
 * Current time in seconds from CLOCK_MONOTONIC, for measuring intervals.
 *
 * If clock_gettime() fails, the timespec would otherwise be read
 * uninitialized. In that case this falls back to processor time from clock()
 * so callers still get a finite value.
 */
static double get_time(void) {
    struct timespec ts = {0};
    if (clock_gettime(CLOCK_MONOTONIC, &ts) != 0) {
        return (double)clock() / CLOCKS_PER_SEC;
    }
    return (double)ts.tv_sec + (double)ts.tv_nsec / 1e9;
}

// ---------------------------------------------------------------------
//            Top-level timing macros for Preprocess/Render
// ---------------------------------------------------------------------
// TIME_START(name): declares a local `double name##_start` initialised with
// get_time(). Because it expands to a declaration (and already ends in ';'),
// it must be used where a declaration is allowed, at most once per `name`
// in a given scope.
#define TIME_START(name) double name##_start = get_time();
// TIME_END(name, label): prints "[TIMING] <label>: <seconds> seconds" for the
// time elapsed since the matching TIME_START(name), then flushes stdout.
// `label` is printed with %s, so it must be a C string.
#define TIME_END(name, label) \
    do { \
      double name##_end = get_time(); \
      double name##_delta = (name##_end - name##_start); \
      printf("[TIMING] %s: %.6f seconds\n", label, name##_delta); \
      fflush(stdout); \
    } while(0)

// ---------------------------------------------------------------------
// EXCLUSIVE timing accumulators for tile/pixel/alpha
// ---------------------------------------------------------------------
// File-scope profiling state (internal linkage). No code before renderCPU
// reads or writes these. NOTE(review): presumably renderCPU, or code after
// it, accumulates into them (times in seconds, as returned by get_time());
// confirm where they are updated, reset and reported.
static double g_tile_excl = 0.0;  // tile overhead
static double g_pixel_excl = 0.0;  // pixel overhead
static double g_alpha = 0.0;      // alpha loop

static int g_tiles_count = 0;    // number of tiles processed
static int g_pixels_count = 0;   // number of pixels processed
static int g_alpha_passes = 0;   // number of alpha loops processed

// ---------------------------------------------------------------------
// Function pointer for pixel processing selection
// ---------------------------------------------------------------------
// Pointer to the per-pixel blending routine chosen by
// initialize_process_pixel_mode(). It starts out NULL, so that function must
// run before the pointer is called.
//
// This is an object (a pointer variable), not a function, so it may not carry
// the `inline` specifier (C11 6.7.4p1). A `static inline` spelling is
// rejected by GCC and Clang.
static void (*selected_process_pixel)(
    uint32_t sub_px, uint32_t sub_py, int W, int H,
    const uint32_t *point_list, const uint2 *ranges,
    const float2 *points_xy_image, const float *features,
    const float4 *conic_opacity, float *final_T,
    uint32_t *n_contrib, const float *bg_color,
    float *out_color, int tile_idx) = NULL;

// ---------------------------------------------------------------------
// Function to set the appropriate pixel processing function
// ---------------------------------------------------------------------
static void initialize_process_pixel_mode() {
    switch (PROCESS_PIXEL_MODE) {
        case 1: selected_process_pixel = process_pixel_bfloat8; break;
        case 2: selected_process_pixel = process_pixel_e5m2; break;
        case 3: selected_process_pixel = process_pixel_bfloat16; break;
        default: selected_process_pixel = process_pixel; break;
    }
}

// ---------------------------------------------------------------------
// Definitions for constants and other helper functions
// ---------------------------------------------------------------------

/*
 * Spherical-harmonics coefficients used by computeColorFromSH:
 *   SH_C0 - constant (band 0) term
 *   SH_C1 - linear (band 1) terms
 *   SH_C2 - the five quadratic (band 2) terms, in sh[4]..sh[8] order
 *   SH_C3 - the seven cubic (band 3) terms, in sh[9]..sh[15] order
 */
const float SH_C0 = 0.28209479177387814f;
const float SH_C1 = 0.4886025119029199f;
const float SH_C2[5] = {
    1.0925484305920792f, -1.0925484305920792f, 0.31539156525252005f,
    -1.0925484305920792f, 0.5462742152960396f,
};
const float SH_C3[7] = {
    -0.5900435899266435f, 2.890611442640554f, -0.4570457994644658f,
    0.3731763325901154f, -0.4570457994644658f, 1.445305721320277f,
    -0.5900435899266435f,
};

/*
 * Map a normalized-device coordinate v in [-1, 1] to a pixel coordinate for
 * an axis S pixels long: -1 -> -0.5 and 1 -> S - 0.5. The arithmetic is done
 * in double and rounded to float once at the end.
 */
static inline float cpu_rasterizer_ndc2Pix(float v, int S)
{
    const double stretched = ((double)v + 1.0) * (double)S;
    const double centred = (stretched - 1.0) * 0.5;
    return (float)centred;
}

/* Smaller of two ints. */
static inline int min_int(int a, int b) {
    if (b < a) {
        return b;
    }
    return a;
}

/* Component-wise difference a - b. */
static inline float3 vec3_sub(float3 a, float3 b)
{
    return (float3){ a.x - b.x, a.y - b.y, a.z - b.z };
}


/* Divide each component of `a` by `scalar` (true division, no reciprocal). */
static inline float3 vec3_div_scalar(float3 a, float scalar)
{
    return (float3){ a.x / scalar, a.y / scalar, a.z / scalar };
}


/* Scale every component of `a` by `scalar`. */
static inline float3 vec3_mul_scalar(float3 a, float scalar)
{
    return (float3){ a.x * scalar, a.y * scalar, a.z * scalar };
}


/* Component-wise sum a + b. */
static inline float3 vec3_add(float3 a, float3 b)
{
    return (float3){ a.x + b.x, a.y + b.y, a.z + b.z };
}


/* Dot product, summed in x, y, z order. */
static inline float vec3_dot(float3 a, float3 b)
{
    const float dot = a.x * b.x + a.y * b.y + a.z * b.z;
    return dot;
}


/* Euclidean length: square root of the dot product of `a` with itself. */
static inline float vec3_length(float3 a)
{
    return sqrt(vec3_dot(a, a));
}


/* Clamp each component of `a` from below at `scalar` (fmax semantics, so a
 * NaN component is replaced by `scalar`). */
static inline float3 vec3_max_scalar(float3 a, float scalar)
{
    return (float3){ fmax(a.x, scalar), fmax(a.y, scalar), fmax(a.z, scalar) };
}


/* Store 1 in clamped[i] when component i of `a` (x, y, z order) is below
 * `scalar`, else 0. `clamped` must have room for three ints. */
static inline void vec3_less_than_scalar(float3 a, float scalar, int *clamped)
{
    const float comps[3] = { a.x, a.y, a.z };
    for (int i = 0; i < 3; ++i) {
        clamped[i] = comps[i] < scalar;
    }
}

/*
 * Evaluate the view-dependent colour of Gaussian `idx` from its
 * spherical-harmonics (SH) coefficients.
 *
 * idx         Gaussian index.
 * deg         SH degree: band 1 is added when deg > 0, band 2 when deg > 1,
 *             band 3 when deg > 2 (values above 3 behave like 3).
 * max_coeffs  number of float3 coefficients per Gaussian in `shs` (stride).
 * means       Gaussian centres; means[idx] is used to form the view direction.
 * campos      camera position.
 * shs         coefficient storage, reinterpreted as float3. Gaussian idx
 *             starts at float3 element idx * max_coeffs, and up to 16
 *             elements (sh[0]..sh[15]) are read.
 *             NOTE(review): assumes float3 is exactly three packed floats.
 * clamped     output: clamped[3*idx + c] = 1 if channel c was negative before
 *             the final clamp, else 0.
 *
 * Returns the SH colour plus 0.5, with each channel clamped below at 0.
 */
float3 computeColorFromSH(int idx, int deg, int max_coeffs, const float3 *means, float3 campos, const float *shs, int *clamped)
{
    float3 pos = means[idx];
    float3 dir = vec3_sub(pos, campos);
    float length = vec3_length(dir);
    // Normalise to a unit view direction. NOTE(review): if campos equals
    // means[idx], length is 0 and dir becomes non-finite.
    dir = vec3_div_scalar(dir, length);

    const float3 *sh = (const float3 *)shs + idx * max_coeffs;
    // Band 0: constant term.
    float3 result = vec3_mul_scalar(sh[0], SH_C0);

    if (deg > 0)
    {
        float x = dir.x;
        float y = dir.y;
        float z = dir.z;

        // Band 1: terms linear in the direction (sh[1]..sh[3]).
        result = vec3_sub(result,
                          vec3_mul_scalar(sh[1], SH_C1 * y));
        result = vec3_add(result,
                          vec3_mul_scalar(sh[2], SH_C1 * z));
        result = vec3_sub(result,
                          vec3_mul_scalar(sh[3], SH_C1 * x));

        if (deg > 1)
        {
            float xx = x * x, yy = y * y, zz = z * z;
            float xy = x * y, yz = y * z, xz = x * z;

            // Band 2: quadratic terms (sh[4]..sh[8]).
            result = vec3_add(result,
                              vec3_mul_scalar(sh[4], SH_C2[0] * xy));
            result = vec3_add(result,
                              vec3_mul_scalar(sh[5], SH_C2[1] * yz));
            result = vec3_add(result,
                              vec3_mul_scalar(sh[6], SH_C2[2] * (2.0f * zz - xx - yy)));
            result = vec3_add(result,
                              vec3_mul_scalar(sh[7], SH_C2[3] * xz));
            result = vec3_add(result,
                              vec3_mul_scalar(sh[8], SH_C2[4] * (xx - yy)));

            if (deg > 2)
            {
                // Band 3: cubic terms (sh[9]..sh[15]).
                result = vec3_add(result,
                                  vec3_mul_scalar(sh[9], SH_C3[0] * y * (3.0f * xx - yy)));
                result = vec3_add(result,
                                  vec3_mul_scalar(sh[10], SH_C3[1] * xy * z));
                result = vec3_add(result,
                                  vec3_mul_scalar(sh[11], SH_C3[2] * y * (4.0f * zz - xx - yy)));
                result = vec3_add(result,
                                  vec3_mul_scalar(sh[12], SH_C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy)));
                result = vec3_add(result,
                                  vec3_mul_scalar(sh[13], SH_C3[4] * x * (4.0f * zz - xx - yy)));
                result = vec3_add(result,
                                  vec3_mul_scalar(sh[14], SH_C3[5] * z * (xx - yy)));
                result = vec3_add(result,
                                  vec3_mul_scalar(sh[15], SH_C3[6] * x * (xx - 3.0f * yy)));
            }
        }
    }

    // Offset the SH sum by 0.5 in every channel.
    result = vec3_add(result, (float3){0.5f, 0.5f, 0.5f});

    // Record which channels are negative before clamping. Presumably these
    // flags are consumed by a gradient/backward pass; verify against callers.
    clamped[3 * idx + 0] = (result.x < 0.0f) ? 1 : 0;
    clamped[3 * idx + 1] = (result.y < 0.0f) ? 1 : 0;
    clamped[3 * idx + 2] = (result.z < 0.0f) ? 1 : 0;

    // Clamp negative channels to 0.
    result = vec3_max_scalar(result, 0.0f);

    return result;
}

/*
 * Project a Gaussian's 3D covariance to a 2D screen-space covariance using
 * the first-order (Jacobian) approximation of the perspective projection.
 *
 * mean        world-space centre of the Gaussian.
 * focal_x/y   focal lengths in pixels.
 * tan_fovx/y  field-of-view tangents; the view-space position is clamped to
 *             +-1.3 * tan_fov (per axis) before building the Jacobian.
 * cov3D       packed upper triangle of the symmetric 3x3 covariance:
 *             {xx, xy, xz, yy, yz, zz}.
 * viewmatrix  view matrix; entries {0,4,8}, {1,5,9}, {2,6,10} form the rows
 *             of the 3x3 block W applied here.
 *
 * Returns {cov.xx, cov.xy, cov.yy} of the 2D covariance, with 0.3 added to
 * both diagonal entries.
 *
 * All matrices are row-major 3x3 arrays (element (r, c) at [r*3 + c]), so the
 * screen covariance is  (J W) Vrk (J W)^T.  Transcribing the GLM
 * (column-major) form  transpose(W*J) * Vrk * (W*J)  directly into row-major
 * arrays would instead give  J^T W^T Vrk W J. That form loses the depth
 * terms for off-axis Gaussians and applies W in the opposite direction.
 */
static inline float3 computeCov2D(const float3 mean, float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float *cov3D, const float *viewmatrix)
{
    float3 t = cpu_rasterizer_transformPoint4x3(mean, viewmatrix);

    // Clamp x/z and y/z to +-1.3 * tan_fov so the Jacobian stays bounded for
    // Gaussians far outside the view.
    const float limx = 1.3f * tan_fovx;
    const float limy = 1.3f * tan_fovy;
    const float txtz = t.x / t.z;
    const float tytz = t.y / t.z;
    t.x = fminf(limx, fmaxf(-limx, txtz)) * t.z;
    t.y = fminf(limy, fmaxf(-limy, tytz)) * t.z;

    // Jacobian of (x, y, z) -> (focal_x * x / z, focal_y * y / z) at t.
    // Row 2 is zero: only the upper-left 2x2 of the result is used.
    const float J[9] = {
        focal_x / t.z, 0.0f, -(focal_x * t.x) / (t.z * t.z),
        0.0f, focal_y / t.z, -(focal_y * t.y) / (t.z * t.z),
        0.0f, 0.0f, 0.0f};

    const float W[9] = {
        viewmatrix[0], viewmatrix[4], viewmatrix[8],
        viewmatrix[1], viewmatrix[5], viewmatrix[9],
        viewmatrix[2], viewmatrix[6], viewmatrix[10]};

    // Full symmetric covariance expanded from its packed upper triangle.
    const float Vrk[9] = {
        cov3D[0], cov3D[1], cov3D[2],
        cov3D[1], cov3D[3], cov3D[4],
        cov3D[2], cov3D[4], cov3D[5]};

    // T = J * W
    float T[9];
    for (int r = 0; r < 3; r++)
    {
        for (int c = 0; c < 3; c++)
        {
            float acc = 0.0f;
            for (int k = 0; k < 3; k++)
            {
                acc += J[r * 3 + k] * W[k * 3 + c];
            }
            T[r * 3 + c] = acc;
        }
    }

    // TV = T * Vrk
    float TV[9];
    for (int r = 0; r < 3; r++)
    {
        for (int c = 0; c < 3; c++)
        {
            float acc = 0.0f;
            for (int k = 0; k < 3; k++)
            {
                acc += T[r * 3 + k] * Vrk[k * 3 + c];
            }
            TV[r * 3 + c] = acc;
        }
    }

    // cov = TV * T^T   (element (k, c) of T^T is T[c * 3 + k])
    float cov[9];
    for (int r = 0; r < 3; r++)
    {
        for (int c = 0; c < 3; c++)
        {
            float acc = 0.0f;
            for (int k = 0; k < 3; k++)
            {
                acc += TV[r * 3 + k] * T[c * 3 + k];
            }
            cov[r * 3 + c] = acc;
        }
    }

    // Dilate the footprint by 0.3 on the diagonal (presumably the EWA
    // low-pass term). This also keeps the 2x2 block away from singularity.
    cov[0] += 0.3f;
    cov[4] += 0.3f;

    float3 result = {cov[0], cov[1], cov[4]};
    return result;
}

/*
 * Build a Gaussian's 3D covariance from its per-axis scale and rotation
 * quaternion:  Sigma = R S S^T R^T,  with S = diag(mod * scale).
 *
 * Sigma is symmetric, so only its upper triangle is written:
 *   cov3D[0..5] = {xx, xy, xz, yy, yz, zz}.
 *
 * R is stored row-major (element (r, c) at [r*3 + c]), so the product is
 * (R S)(R S)^T. Transcribing the GLM (column-major) form
 * transpose(S*R) * (S*R) literally would give R^T S^2 R, which is the
 * covariance for the inverse rotation.
 *
 * NOTE(review): rot.w is treated as the scalar part and (rot.x, rot.y, rot.z)
 * as the vector part. Confirm this matches how callers pack `rotations`; a
 * (w, x, y, z) layout with the scalar in rot.x would need remapping. The
 * quaternion is not normalised here.
 */
static inline void computeCov3D(const float3 scale, float mod, const float4 rot, float *cov3D)
{
    const float s[3] = {mod * scale.x, mod * scale.y, mod * scale.z};

    const float qx = rot.x;
    const float qy = rot.y;
    const float qz = rot.z;
    const float qw = rot.w;

    // Rotation matrix of the unit quaternion (qw; qx, qy, qz), row-major.
    const float R[9] = {
        1.f - 2.f * (qy * qy + qz * qz), 2.f * (qx * qy - qw * qz), 2.f * (qx * qz + qw * qy),
        2.f * (qx * qy + qw * qz), 1.f - 2.f * (qx * qx + qz * qz), 2.f * (qy * qz - qw * qx),
        2.f * (qx * qz - qw * qy), 2.f * (qy * qz + qw * qx), 1.f - 2.f * (qx * qx + qy * qy)};

    // M = R * S: scale column c of R by s[c].
    float M[9];
    for (int r = 0; r < 3; r++)
    {
        for (int c = 0; c < 3; c++)
        {
            M[r * 3 + c] = R[r * 3 + c] * s[c];
        }
    }

    // Sigma = M * M^T
    float Sigma[9];
    for (int r = 0; r < 3; r++)
    {
        for (int c = 0; c < 3; c++)
        {
            float acc = 0.0f;
            for (int k = 0; k < 3; k++)
            {
                acc += M[r * 3 + k] * M[c * 3 + k];
            }
            Sigma[r * 3 + c] = acc;
        }
    }

    cov3D[0] = Sigma[0];
    cov3D[1] = Sigma[1];
    cov3D[2] = Sigma[2];
    cov3D[3] = Sigma[4];
    cov3D[4] = Sigma[5];
    cov3D[5] = Sigma[8];
}


/*
 * CPU preprocessing pass over P Gaussians. Each one is frustum-culled,
 * projected, and given a 2D conic, a pixel radius and a tile footprint; its
 * colour is optionally evaluated from spherical harmonics.
 *
 * For each idx in [0, P):
 *   - radii[idx] and tiles_touched[idx] are reset to 0 and stay 0 if the
 *     Gaussian is culled: it fails cpu_rasterizer_in_frustum, its 2D
 *     covariance has zero determinant, or its tile rectangle is empty.
 *   - cov3D is read from cov3D_precomp when that is non-NULL. Otherwise it is
 *     computed from scales/rotations/scale_modifier into
 *     cov3Ds[idx*6 .. idx*6+5]; this happens even if the Gaussian is culled
 *     later.
 *   - for surviving Gaussians: when colors_precomp is NULL, rgb[idx*NUM_CHANNELS
 *     + 0..2] gets the SH colour (degree D, M coefficients per Gaussian) and
 *     clamped[] the per-channel clamp flags. depths, radii, points_xy_image,
 *     conic_opacity (inverse 2D covariance in xyz, opacity in w) and
 *     tiles_touched are also written.
 */
void preprocessCPU(int P, int D, int M,
    const float *orig_points,
    const float3 *scales,
    const float scale_modifier,
    const float4 *rotations,
    const float *opacities,
    const float *shs,
    int *clamped,
    const float *cov3D_precomp,
    const float *colors_precomp,
    const float *viewmatrix,
    const float *projmatrix,
    const float3 *cam_pos,
    const int W, int H,
    const float tan_fovx, float tan_fovy,
    const float focal_x, const float focal_y,
    int *radii,
    float2 *points_xy_image,
    float *depths,
    float *cov3Ds,
    float *rgb,
    float4 *conic_opacity,
    const dim3 grid,
    uint32_t *tiles_touched,
    int prefiltered)
{
    for (int idx = 0; idx < P; idx++)
    {
        radii[idx] = 0;
        tiles_touched[idx] = 0;

        // Frustum culling; also yields the view-space position.
        float3 p_view;
        if (!cpu_rasterizer_in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered, &p_view))
        {
            continue;
        }

        // Project the centre; the epsilon avoids dividing by exactly w == 0.
        float3 p_orig = {orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2]};
        float4 p_hom = cpu_rasterizer_transformPoint4x4(p_orig, projmatrix);
        float p_w = 1.0f / (p_hom.w + 0.0000001f);
        float3 p_proj = {p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w};

        // 3D covariance: precomputed, or built from scale + rotation.
        const float *cov3D;
        if (cov3D_precomp != NULL)
        {
            cov3D = cov3D_precomp + idx * 6;
        }
        else
        {
            computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6);
            cov3D = cov3Ds + idx * 6;
        }

        float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix);

        // Invert the 2x2 covariance to get the conic; skip singular ones.
        float det = (cov.x * cov.z - cov.y * cov.y);
        if (det == 0.0f)
        {
            continue;
        }
        float det_inv = 1.0f / det;
        float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv };

        // Radius: 3 standard deviations along the major axis, i.e. from the
        // larger eigenvalue of the 2x2 covariance (discriminant floored at 0.1).
        float mid = 0.5f * (cov.x + cov.z);
        float delta = mid * mid - det;
        delta = fmax(delta, 0.1f);
        float sqrt_delta = sqrt(delta);
        float lambda1 = mid + sqrt_delta;
        float lambda2 = mid - sqrt_delta;
        float my_radius = ceil(3.0f * sqrt(fmax(lambda1, lambda2)));
        float2 point_image = {cpu_rasterizer_ndc2Pix(p_proj.x, W), cpu_rasterizer_ndc2Pix(p_proj.y, H)};

        // Skip Gaussians whose screen rectangle covers no tiles.
        uint2 rect_min, rect_max;
        cpu_rasterizer_getRect(point_image, (int)my_radius, &rect_min, &rect_max, grid);
        if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0)
        {
            continue;
        }

        if (colors_precomp == NULL)
        {
            float3 campos = *cam_pos;
            float3 result = computeColorFromSH(idx, D, M, (float3*)orig_points, campos, shs, clamped);
            rgb[idx * NUM_CHANNELS + 0] = result.x;
            rgb[idx * NUM_CHANNELS + 1] = result.y;
            rgb[idx * NUM_CHANNELS + 2] = result.z;
        }

        depths[idx] = p_view.z;
        radii[idx] = (int)my_radius;
        points_xy_image[idx] = point_image;
        conic_opacity[idx].x = conic.x;
        conic_opacity[idx].y = conic.y;
        conic_opacity[idx].z = conic.z;
        conic_opacity[idx].w = opacities[idx];
        tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x);
    }
}
#include <stdint.h>
#include <math.h>
//E4M3 Representation helper functions
// Same typedef as at the top of the file, repeated so the E4M3 helpers are
// self-contained (C11 allows redefining a typedef to the same type).
typedef uint8_t bfloat8;

/*
 * float -> 8-bit E4M3-style value: 1 sign bit, 4 exponent bits (bias 7),
 * 3 mantissa bits, rounded by truncation toward zero.
 *
 * This is not the OCP FP8 E4M3FN encoding. Exponent 0xF is reserved for
 * infinity, so the largest finite magnitude is 1.875 * 2^7 = 240, and
 * magnitudes >= 256 become infinity. NaN is also encoded as infinity.
 * Float subnormals, and values below the smallest E4M3 subnormal (2^-9),
 * become signed zero.
 *
 * Bits are read through a union (well-defined in C11) instead of a
 * uint32_t* cast, which violates strict aliasing.
 */
static inline bfloat8 float_to_bfloat8(float val) {
    union { float f; uint32_t u; } pun;
    pun.f = val;
    const int sign = (int)(pun.u >> 31);
    const int exp = (int)((pun.u >> 23) & 0xFFu);
    const int mantissa = (int)(pun.u & 0x7FFFFFu);

    // Zero or float subnormal: flush to signed zero.
    if (exp == 0) {
        return (bfloat8)(sign << 7);
    }
    // Inf and NaN both map to infinity (exponent 0xF, mantissa 0).
    if (exp == 0xFF) {
        return (bfloat8)((sign << 7) | (0xF << 3));
    }

    const int new_exp = exp - 127 + 7;  // re-bias 127 -> 7
    if (new_exp >= 0xF) {
        // Overflow: saturate to infinity.
        return (bfloat8)((sign << 7) | (0xF << 3));
    }
    if (new_exp <= 0) {
        // E4M3 subnormal: value = (m / 8) * 2^-6. Shift the 24-bit
        // significand (implicit 1 included) so its top bits land in the
        // 3-bit field.
        const int shift = (1 - new_exp) + (23 - 3);
        if (shift >= 24) {
            // Every significant bit is shifted out. Returning here also avoids
            // shifting an int by >= 32 bits, which is undefined behaviour.
            return (bfloat8)(sign << 7);
        }
        const int sub_mant = ((mantissa | (1 << 23)) >> shift) & 0x7;
        return (bfloat8)((sign << 7) | sub_mant);
    }
    // Normal: keep the top 3 mantissa bits.
    return (bfloat8)((sign << 7) | (new_exp << 3) | ((mantissa >> (23 - 3)) & 0x7));
}

/*
 * E4M3-style value -> float (exact; every encoding is representable).
 * Exponent 0 gives signed zero or a subnormal (m / 8) * 2^-6. Exponent 0xF
 * gives signed infinity regardless of mantissa.
 *
 * The sign is shifted as uint32_t because `1 << 31` on an int is undefined
 * behaviour. The bits are reinterpreted through a union.
 */
static inline float bfloat8_to_float(bfloat8 val) {
    const uint32_t sign = ((uint32_t)val >> 7) & 0x1u;
    const int exp = (val >> 3) & 0xF;
    const int mantissa = val & 0x7;
    union { uint32_t u; float f; } pun;

    if (exp == 0) {
        if (mantissa == 0) {
            pun.u = sign << 31;  // signed zero
            return pun.f;
        }
        // Subnormal: (mantissa / 8) * 2^(1 - 7); 2^-6 == 1/64.
        const float sub = ((float)mantissa / 8.0f) * (1.0f / 64.0f);
        return sign ? -sub : sub;
    }
    if (exp == 0xF) {
        pun.u = (sign << 31) | (0xFFu << 23);  // infinity
        return pun.f;
    }
    // Normal: re-bias 7 -> 127 and move the mantissa into the top of the
    // 23-bit field.
    pun.u = (sign << 31)
          | ((uint32_t)(exp - 7 + 127) << 23)
          | ((uint32_t)mantissa << (23 - 3));
    return pun.f;
}

//---------------------------
// Pixel processing function using bfloat8 (E4M3)
//---------------------------

/*
 * Blend the Gaussians of tile `tile_idx` into pixel (sub_px, sub_py). Every
 * intermediate value is rounded through the 8-bit E4M3 format
 * (float_to_bfloat8 / bfloat8_to_float): positions, offsets, conic/opacity,
 * power, alpha, transmittance T, features and the colour accumulators.
 *
 * The outputs follow the same convention as process_pixel:
 *   final_T[pix_id]   final transmittance
 *   n_contrib[pix_id] number of Gaussians visited up to and including the
 *                     last one that was blended (a count, not an id)
 *   out_color         channel-planar: out_color[ch * H * W + pix_id] =
 *                     accumulated colour + T * bg_color[ch]
 *
 * NOTE(review): this E4M3 variant saturates to infinity at 256, so pixel
 * coordinates >= 256 become infinite and produce non-finite offsets.
 */
static inline void process_pixel_bfloat8(
    uint32_t sub_px, uint32_t sub_py, 
    int W, int H, 
    const uint32_t *point_list,
    const uint2 *ranges,
    const float2 *points_xy_image,
    const float *features,
    const float4 *conic_opacity,
    float *final_T,
    uint32_t *n_contrib,
    const float *bg_color,
    float *out_color,
    int tile_idx)
{
    uint32_t pix_id = W * sub_py + sub_px;
    float2 pixf = { (float)sub_px, (float)sub_py };

    bfloat8 T_bf8 = float_to_bfloat8(1.0f);
    // `contributor` counts every Gaussian visited. last_contributor snapshots
    // it at the last accepted blend, matching process_pixel's n_contrib.
    uint32_t contributor = 0;
    uint32_t last_contributor = 0;
    bfloat8 C_bf8[NUM_CHANNELS] = {0};

    for (int idx = ranges[tile_idx].x; idx < ranges[tile_idx].y; idx++) {
        contributor++;
        int coll_id = point_list[idx];

        // Image-space position and offset to this pixel, in E4M3.
        bfloat8 xy_x = float_to_bfloat8(points_xy_image[coll_id].x);
        bfloat8 xy_y = float_to_bfloat8(points_xy_image[coll_id].y);
        bfloat8 d_x = float_to_bfloat8(bfloat8_to_float(xy_x) - pixf.x);
        bfloat8 d_y = float_to_bfloat8(bfloat8_to_float(xy_y) - pixf.y);

        // Conic (x, y, z) and opacity (w), in E4M3.
        bfloat8 con_o_x = float_to_bfloat8(conic_opacity[coll_id].x);
        bfloat8 con_o_y = float_to_bfloat8(conic_opacity[coll_id].y);
        bfloat8 con_o_z = float_to_bfloat8(conic_opacity[coll_id].z);
        bfloat8 con_o_w = float_to_bfloat8(conic_opacity[coll_id].w);

        // power = -0.5 * (cx*dx^2 + cz*dy^2) - cy*dx*dy
        bfloat8 power = float_to_bfloat8(
            -0.5f * (
                bfloat8_to_float(con_o_x) * bfloat8_to_float(d_x) * bfloat8_to_float(d_x) +
                bfloat8_to_float(con_o_z) * bfloat8_to_float(d_y) * bfloat8_to_float(d_y)
            )
            - ( bfloat8_to_float(con_o_y) * bfloat8_to_float(d_x) * bfloat8_to_float(d_y) )
        );

        if (bfloat8_to_float(power) > 0.0f)
            continue;

        // alpha = min(0.99, opacity * exp(power)), evaluated in float.
        bfloat8 alpha_bf8 = float_to_bfloat8(fmin(0.99f, bfloat8_to_float(con_o_w) * exp(bfloat8_to_float(power))));
        if (bfloat8_to_float(alpha_bf8) < (1.0f / 255.0f))
            continue;

        // Transmittance after this Gaussian; stop once it is negligible.
        bfloat8 test_T = float_to_bfloat8(bfloat8_to_float(T_bf8) * (1.0f - bfloat8_to_float(alpha_bf8)));
        if (bfloat8_to_float(test_T) < 0.0001f)
            break;

        // Blend colour contributions in E4M3.
        for (int ch = 0; ch < NUM_CHANNELS; ch++) {
            bfloat8 feat = float_to_bfloat8(features[coll_id * NUM_CHANNELS + ch]);
            bfloat8 contrib = float_to_bfloat8(
                bfloat8_to_float(feat) * bfloat8_to_float(alpha_bf8) * bfloat8_to_float(T_bf8)
            );
            C_bf8[ch] = float_to_bfloat8(bfloat8_to_float(C_bf8[ch]) + bfloat8_to_float(contrib));
        }

        T_bf8 = test_T;
        last_contributor = contributor;
    }

    final_T[pix_id] = bfloat8_to_float(T_bf8);
    n_contrib[pix_id] = last_contributor;

    for (int ch = 0; ch < NUM_CHANNELS; ch++) {
        bfloat8 bg_color_bf8 = float_to_bfloat8(bg_color[ch]);
        bfloat8 final_val = float_to_bfloat8(
            bfloat8_to_float(C_bf8[ch]) + bfloat8_to_float(T_bf8) * bfloat8_to_float(bg_color_bf8)
        );
        out_color[ch * H * W + pix_id] = bfloat8_to_float(final_val);
    }
}
#include <stdint.h>
#include <math.h>

// 8-bit float: E5M2 => 1 sign bit, 5 exponent bits (bias=15), 2 mantissa bits.
// Same typedef as at the top of the file (C11 allows the identical redefinition).
typedef uint8_t float8_e5m2;

/*
 * float -> E5M2, rounded by truncation toward zero.
 *
 * Exponent 31 encodes Inf (mantissa 0) or NaN (mantissa 1), so the largest
 * finite magnitude is 1.75 * 2^15 = 57344, and |val| >= 65536 becomes
 * infinity. Float subnormals, and values below the smallest E5M2 subnormal
 * (2^-16), become signed zero.
 */
static inline float8_e5m2 float_to_e5m2(float val) {
    union {
        float f;
        uint32_t u;
    } v;
    v.f = val;

    const int sign     = (int)((v.u >> 31) & 0x1u);
    const int exp      = (int)((v.u >> 23) & 0xFFu);  // 8-bit biased exponent
    const int mantissa = (int)(v.u & 0x7FFFFFu);      // 23-bit fraction

    // Zero or float subnormal: flush to signed zero.
    if (exp == 0) {
        return (float8_e5m2)(sign << 7);
    }

    // Inf -> exp=31, mant=0; NaN -> exp=31, mant=1.
    if (exp == 0xFF) {
        return (float8_e5m2)((sign << 7) | (0x1F << 2) | (mantissa != 0 ? 0x1 : 0x0));
    }

    const int new_exp = exp - 127 + 15;  // re-bias 127 -> 15
    if (new_exp >= 0x1F) {
        // Exponent out of range: saturate to infinity.
        return (float8_e5m2)((sign << 7) | (0x1F << 2));
    }
    if (new_exp <= 0) {
        // E5M2 subnormal: value = (m / 4) * 2^-14. Shift the 24-bit
        // significand (implicit 1 included) so its top bits land in the
        // 2-bit field.
        const int shift = (1 - new_exp) + (23 - 2);
        if (shift >= 24) {
            // Every significant bit is shifted out. Returning here also avoids
            // shifting an int by >= 32 bits, which is undefined behaviour.
            return (float8_e5m2)(sign << 7);
        }
        const int sub_mant = (((1 << 23) | mantissa) >> shift) & 0x3;
        return (float8_e5m2)((sign << 7) | sub_mant);
    }

    // Normal: keep the top 2 fraction bits.
    return (float8_e5m2)((sign << 7) | (new_exp << 2) | ((mantissa >> (23 - 2)) & 0x3));
}

/*
 * E5M2 -> float (exact for every finite encoding).
 * Exponent 31 gives Inf (mantissa 0) or a NaN; exponent 0 gives signed zero
 * or the subnormal (m / 4) * 2^-14.
 *
 * The sign is shifted as uint32_t: `1 << 31` on an int is undefined behaviour.
 */
static inline float e5m2_to_float(float8_e5m2 val) {
    const uint32_t sign = ((uint32_t)val >> 7) & 0x1u;
    const int exp       = (val >> 2) & 0x1F;  // 5-bit exponent
    const int mantissa  = val & 0x3;          // 2-bit mantissa

    union {
        uint32_t u;
        float f;
    } v;

    if (exp == 0x1F) {
        v.u = (sign << 31) | (0xFFu << 23) | (mantissa ? 0x1u : 0x0u);
        return v.f;
    }

    if (exp == 0) {
        if (mantissa == 0) {
            v.u = sign << 31;  // signed zero
            return v.f;
        }
        // Subnormal: (mantissa / 4) * 2^(1 - 15); 2^-14 == 1/16384.
        const float sub = ((float)mantissa / 4.0f) * (1.0f / 16384.0f);
        return sign ? -sub : sub;
    }

    // Normal: re-bias 15 -> 127 and move the mantissa into the top of the
    // 23-bit field.
    v.u = (sign << 31)
        | ((uint32_t)(exp - 15 + 127) << 23)
        | ((uint32_t)mantissa << (23 - 2));
    return v.f;
}



/*
 * Blend the Gaussians of tile `tile_idx` into pixel (sub_px, sub_py), with
 * mixed E5M2 emulation:
 *   rounded through E5M2: positions, conic/opacity, power, the alpha used
 *     for the skip test, transmittance T and the colour accumulators;
 *   kept in float: the offsets dx/dy, the alpha used for the T update and the
 *     colour contribution, the features, and the final C + T * bg sum.
 *
 * The outputs follow the same convention as process_pixel:
 *   final_T[pix_id]   final transmittance
 *   n_contrib[pix_id] number of Gaussians visited up to and including the
 *                     last one that was blended (a count, not an id)
 *   out_color         channel-planar: out_color[ch * H * W + pix_id]
 */
static inline void process_pixel_e5m2(
    uint32_t sub_px, uint32_t sub_py,
    int W, int H,
    const uint32_t *point_list,
    const uint2 *ranges,
    const float2 *points_xy_image,
    const float *features,
    const float4 *conic_opacity,
    float *final_T,
    uint32_t *n_contrib,
    const float *bg_color,
    float *out_color,
    int tile_idx)
{
    uint32_t pix_id = W * sub_py + sub_px;
    float2 pixf = { (float)sub_px, (float)sub_py };

    // Start with T=1.0 in E5M2
    float8_e5m2 T_e5m2 = float_to_e5m2(1.0f);
    // `contributor` counts every Gaussian visited. last_contributor snapshots
    // it at the last accepted blend, matching process_pixel's n_contrib.
    uint32_t contributor = 0;
    uint32_t last_contributor = 0;

    // Accumulated color in E5M2
    float8_e5m2 C_e5m2[NUM_CHANNELS];
    for (int ch = 0; ch < NUM_CHANNELS; ch++)
        C_e5m2[ch] = float_to_e5m2(0.0f);

    for (int idx = ranges[tile_idx].x; idx < ranges[tile_idx].y; idx++) {
        contributor++;
        int coll_id = point_list[idx];

        // Convert geometry to E5M2 (WARNING: saturates if x>~57344)
        float8_e5m2 xy_x = float_to_e5m2(points_xy_image[coll_id].x);
        float8_e5m2 xy_y = float_to_e5m2(points_xy_image[coll_id].y);

        // Offsets to this pixel, kept in float.
        float dx_f = e5m2_to_float(xy_x) - pixf.x;
        float dy_f = e5m2_to_float(xy_y) - pixf.y;

        // Conic parameters in E5M2
        float8_e5m2 co_x = float_to_e5m2(conic_opacity[coll_id].x);  // diagonal x
        float8_e5m2 co_y = float_to_e5m2(conic_opacity[coll_id].y);  // cross term
        float8_e5m2 co_z = float_to_e5m2(conic_opacity[coll_id].z);  // diagonal y
        float8_e5m2 co_w = float_to_e5m2(conic_opacity[coll_id].w);  // opacity scale

        // power = -0.5*(co_x*d_x^2 + co_z*d_y^2) - (co_y*d_x*d_y)
        float power_f = -0.5f * ( e5m2_to_float(co_x) * dx_f * dx_f
                                + e5m2_to_float(co_z) * dy_f * dy_f )
                        - ( e5m2_to_float(co_y) * dx_f * dy_f );
        float8_e5m2 power_e5m2 = float_to_e5m2(power_f);

        // If power>0 => skip
        if (e5m2_to_float(power_e5m2) > 0.0f)
            continue;

        // alpha = min(0.99, co_w * exp(power))
        float alpha_f = fminf(0.99f, e5m2_to_float(co_w) * expf(e5m2_to_float(power_e5m2)));
        float8_e5m2 alpha_e5m2 = float_to_e5m2(alpha_f);

        // If alpha < ~1/255 => skip
        if (e5m2_to_float(alpha_e5m2) < (1.0f / 255.0f))
            continue;

        // test_T = T*(1-alpha)
        float test_T_f = e5m2_to_float(T_e5m2) * (1.0f - alpha_f);
        float8_e5m2 test_T_e5m2 = float_to_e5m2(test_T_f);
        if (e5m2_to_float(test_T_e5m2) < 0.0001f)
            break;

        // Accumulate color
        for (int ch = 0; ch < NUM_CHANNELS; ch++) {
            float feat_f = features[coll_id * NUM_CHANNELS + ch];
            float c_old  = e5m2_to_float(C_e5m2[ch]);
            float c_add  = feat_f * alpha_f * e5m2_to_float(T_e5m2);
            C_e5m2[ch]   = float_to_e5m2(c_old + c_add);
        }

        T_e5m2 = test_T_e5m2;
        last_contributor = contributor;
    }

    // Convert final T and color back to float
    final_T[pix_id] = e5m2_to_float(T_e5m2);
    n_contrib[pix_id] = last_contributor;

    for (int ch = 0; ch < NUM_CHANNELS; ch++) {
        float c_f  = e5m2_to_float(C_e5m2[ch]);
        float t_f  = e5m2_to_float(T_e5m2);
        float outf = c_f + t_f * bg_color[ch];
        out_color[ch * H * W + pix_id] = outf;
    }
}

// bfloat16: the upper 16 bits of an IEEE-754 binary32 value (1 sign,
// 8 exponent, 7 mantissa bits). Same typedef as at the top of the file.
typedef uint16_t bfloat16;

/*
 * float -> bfloat16 by truncation (round toward zero).
 *
 * A NaN whose payload lies only in the low 16 bits would truncate to the
 * infinity pattern, so for NaN inputs the quiet-NaN bit is forced on to keep
 * the result a NaN. Bits are read through a union (well-defined in C11)
 * rather than a pointer cast, which violates strict aliasing.
 */
static inline bfloat16 float_to_bfloat16(float val) {
    union { float f; uint32_t u; } pun;
    pun.f = val;
    uint16_t hi = (uint16_t)(pun.u >> 16);
    if ((pun.u & 0x7FFFFFFFu) > 0x7F800000u) {
        hi = (uint16_t)(hi | 0x0040u);  // quiet-NaN bit
    }
    return (bfloat16)hi;
}

// bfloat16 -> float (exact): place the 16 bits in the top half of a binary32.
static inline float bfloat16_to_float(bfloat16 val) {
    union { uint32_t u; float f; } pun;
    pun.u = (uint32_t)val << 16;
    return pun.f;
}

/*
 * Blend the Gaussians of tile `tile_idx` into pixel (sub_px, sub_py). Every
 * intermediate value is rounded through bfloat16 (float_to_bfloat16 /
 * bfloat16_to_float).
 *
 * Gaussian falloff: power = -0.5 * (cx*dx^2 + cz*dy^2) - cy*dx*dy, the same
 * formula as process_pixel. (The cross term must not sit inside the -0.5
 * factor, which would halve it.)
 *
 * The outputs follow the same convention as process_pixel:
 *   final_T[pix_id]   final transmittance
 *   n_contrib[pix_id] number of Gaussians visited up to and including the
 *                     last one that was blended (a count, not an id)
 *   out_color         channel-planar: out_color[ch * H * W + pix_id] =
 *                     accumulated colour + T * bg_color[ch]
 */
static inline void process_pixel_bfloat16(
    uint32_t sub_px, uint32_t sub_py, 
    int W, int H, 
    const uint32_t *point_list,
    const uint2 *ranges,
    const float2 *points_xy_image,
    const float *features,
    const float4 *conic_opacity,
    float *final_T,
    uint32_t *n_contrib,
    const float *bg_color,
    float *out_color,
    int tile_idx)
{
    uint32_t pix_id = W * sub_py + sub_px;
    float2 pixf = {(float)sub_px, (float)sub_py};

    bfloat16 T_bf16 = float_to_bfloat16(1.0f);
    // `contributor` counts every Gaussian visited. last_contributor snapshots
    // it at the last accepted blend, matching process_pixel's n_contrib.
    uint32_t contributor = 0;
    uint32_t last_contributor = 0;
    bfloat16 C_bf16[NUM_CHANNELS] = {0};

    for (int idx = ranges[tile_idx].x; idx < ranges[tile_idx].y; idx++) {
        contributor++;
        int coll_id = point_list[idx];

        // Position and offset to this pixel, in bfloat16.
        bfloat16 xy_x = float_to_bfloat16(points_xy_image[coll_id].x);
        bfloat16 xy_y = float_to_bfloat16(points_xy_image[coll_id].y);
        bfloat16 d_x = float_to_bfloat16(bfloat16_to_float(xy_x) - pixf.x);
        bfloat16 d_y = float_to_bfloat16(bfloat16_to_float(xy_y) - pixf.y);

        // Conic (x, y, z) and opacity (w), in bfloat16.
        bfloat16 con_o_x = float_to_bfloat16(conic_opacity[coll_id].x);
        bfloat16 con_o_y = float_to_bfloat16(conic_opacity[coll_id].y);
        bfloat16 con_o_z = float_to_bfloat16(conic_opacity[coll_id].z);
        bfloat16 con_o_w = float_to_bfloat16(conic_opacity[coll_id].w);

        // power = -0.5 * (cx*dx^2 + cz*dy^2) - cy*dx*dy
        bfloat16 power = float_to_bfloat16(
            -0.5f * (bfloat16_to_float(con_o_x) * bfloat16_to_float(d_x) * bfloat16_to_float(d_x) +
                     bfloat16_to_float(con_o_z) * bfloat16_to_float(d_y) * bfloat16_to_float(d_y))
            - bfloat16_to_float(con_o_y) * bfloat16_to_float(d_x) * bfloat16_to_float(d_y)
        );

        if (bfloat16_to_float(power) > 0.0f) continue;

        // alpha = min(0.99, opacity * exp(power)), evaluated in float.
        bfloat16 alpha_bf16 = float_to_bfloat16(fmin(0.99f, bfloat16_to_float(con_o_w) * exp(bfloat16_to_float(power))));
        if (bfloat16_to_float(alpha_bf16) < (1.0f / 255.0f)) continue;

        // Transmittance after this Gaussian; stop once it is negligible.
        bfloat16 test_T = float_to_bfloat16(bfloat16_to_float(T_bf16) * (1.0f - bfloat16_to_float(alpha_bf16)));
        if (bfloat16_to_float(test_T) < 0.0001f) break;

        // Blend colour contributions in bfloat16.
        for (int ch = 0; ch < NUM_CHANNELS; ch++) {
            bfloat16 feat = float_to_bfloat16(features[coll_id * NUM_CHANNELS + ch]);
            bfloat16 contrib = float_to_bfloat16(bfloat16_to_float(feat) * bfloat16_to_float(alpha_bf16) * bfloat16_to_float(T_bf16));
            C_bf16[ch] = float_to_bfloat16(bfloat16_to_float(C_bf16[ch]) + bfloat16_to_float(contrib));
        }

        T_bf16 = test_T;
        last_contributor = contributor;
    }

    final_T[pix_id] = bfloat16_to_float(T_bf16);
    n_contrib[pix_id] = last_contributor;

    for (int ch = 0; ch < NUM_CHANNELS; ch++) {
        bfloat16 bg_color_bf16 = float_to_bfloat16(bg_color[ch]);
        bfloat16 final_val = float_to_bfloat16(bfloat16_to_float(C_bf16[ch]) + bfloat16_to_float(T_bf16) * bfloat16_to_float(bg_color_bf16));
        out_color[ch * H * W + pix_id] = bfloat16_to_float(final_val);
    }
}



/*
 * Full-precision front-to-back alpha blending of the Gaussians of tile
 * `tile_idx` into pixel (sub_px, sub_py).
 *
 * Each Gaussian in ranges[tile_idx] is visited in order. It is skipped when
 * its falloff power is positive or its alpha is below 1/255, and traversal
 * stops once the transmittance would drop below 1e-4.
 *
 * Outputs:
 *   final_T[pix_id]   final transmittance
 *   n_contrib[pix_id] number of Gaussians visited up to and including the
 *                     last one that was blended
 *   out_color         channel-planar: out_color[ch * H * W + pix_id] =
 *                     accumulated colour + T * bg_color[ch]
 */
static inline void process_pixel(
    uint32_t sub_px, uint32_t sub_py, 
    int W, int H, 
    const uint32_t *point_list,
    const uint2 *ranges,
    const float2 *points_xy_image,
    const float *features,
    const float4 *conic_opacity,
    float *final_T,
    uint32_t *n_contrib,
    const float *bg_color,
    float *out_color,
    int tile_idx)
{
    const uint32_t pix_id = W * sub_py + sub_px;
    const float px = (float)sub_px;
    const float py = (float)sub_py;

    float transmittance = 1.0f;
    uint32_t visited = 0;          /* Gaussians examined so far */
    uint32_t last_contributor = 0; /* value of `visited` at the last blend */
    float accum[NUM_CHANNELS] = {0.0f};

    int idx = ranges[tile_idx].x;
    while (idx < ranges[tile_idx].y) {
        const int coll_id = point_list[idx];
        ++idx;
        ++visited;

        const float2 centre = points_xy_image[coll_id];
        const float dx = centre.x - px;
        const float dy = centre.y - py;
        const float4 con_o = conic_opacity[coll_id];

        /* Gaussian falloff in conic form. */
        const float power = -0.5f * (con_o.x * dx * dx + con_o.z * dy * dy) - con_o.y * dx * dy;
        if (power > 0.0f) {
            continue;
        }

        /* Opacity, capped at 0.99; negligible contributions are skipped. */
        const float alpha = fmin(0.99f, con_o.w * exp(power));
        if (alpha < 1.0f / 255.0f) {
            continue;
        }

        /* Stop once the pixel is effectively opaque (thresholding). */
        const float next_T = transmittance * (1 - alpha);
        if (next_T < 0.0001f) {
            break;
        }

        for (int ch = 0; ch < NUM_CHANNELS; ++ch) {
            accum[ch] += features[coll_id * NUM_CHANNELS + ch] * alpha * transmittance;
        }

        transmittance = next_T;
        last_contributor = visited;
    }

    final_T[pix_id] = transmittance;
    n_contrib[pix_id] = last_contributor;

    for (int ch = 0; ch < NUM_CHANNELS; ++ch) {
        out_color[ch * H * W + pix_id] = accum[ch] + transmittance * bg_color[ch];
    }
}
// ---------------------------------------------------------------------
// Rendering function that calls the selected pixel processor
// ---------------------------------------------------------------------
void renderCPU(
    const dim3 grid,
    const uint2 *ranges,
    const uint32_t *point_list,
    int W, int H,
    const float2 *points_xy_image,
    const float *features,
    const float4 *conic_opacity,
    float *final_T,
    uint32_t *n_contrib,
    const float *bg_color,
    float *out_color)
{
    /*
     * Render every BLOCK_X x BLOCK_Y screen tile on the CPU (serially) while
     * estimating how long the work would take if each tile's
     * PE_SIZE x PE_SIZE pixel groups ran concurrently.
     *
     * Every pixel is handed to selected_process_pixel (presumably the
     * mode-specific processor installed by initialize_process_pixel_mode())
     * together with the tile's index into `ranges`; that processor writes
     * final_T, n_contrib and out_color.
     *
     * Timing summary printed at the end:
     *   - parallel-simulated time: per tile, the duration of the slowest
     *     pixel group (groups are assumed to run concurrently), summed
     *     over all tiles;
     *   - tile overhead: per tile, wall time spent outside every pixel
     *     group, summed over all tiles.
     *
     * `grid` is not used: tile counts are derived from W and H.
     */
    (void)grid;

    initialize_process_pixel_mode();  // Ensure correct processing mode is set

    const uint32_t vertical_blocks = (H + BLOCK_Y - 1) / BLOCK_Y;
    const uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
    const uint32_t PE_SIZE = 16;  /* side length of a simulated PE pixel group */

    TIME_START(render);

    double total_parallel_render_time = 0.0;
    double total_tile_overhead = 0.0;

    for (uint32_t tile_y = 0; tile_y < vertical_blocks; tile_y++) {
        for (uint32_t tile_x = 0; tile_x < horizontal_blocks; tile_x++) {
            const double tile_start = get_time();
            double tile_max_group_duration = 0.0;
            double tile_sum_group_duration = 0.0;
            const int tile_idx = (int)(tile_y * horizontal_blocks + tile_x);

            uint2 pix_min = {tile_x * BLOCK_X, tile_y * BLOCK_Y};
            uint2 pix_max = {
                (uint32_t)min_int(pix_min.x + BLOCK_X, W),
                (uint32_t)min_int(pix_min.y + BLOCK_Y, H)
            };

            for (uint32_t py = pix_min.y; py < pix_max.y; py += PE_SIZE) {
                /* Row bound of this group, clipped to the tile (hoisted). */
                const uint32_t py_end = (uint32_t)min_int(py + PE_SIZE, pix_max.y);

                for (uint32_t px = pix_min.x; px < pix_max.x; px += PE_SIZE) {
                    const uint32_t px_end = (uint32_t)min_int(px + PE_SIZE, pix_max.x);
                    const double group_start = get_time();

                    for (uint32_t sub_py = py; sub_py < py_end; ++sub_py) {
                        for (uint32_t sub_px = px; sub_px < px_end; ++sub_px) {
                            selected_process_pixel(
                                sub_px, sub_py,
                                W, H,
                                point_list, ranges,
                                points_xy_image, features, conic_opacity,
                                final_T, n_contrib,
                                bg_color, out_color,
                                tile_idx);
                        }
                    }

                    const double group_duration = get_time() - group_start;
                    tile_sum_group_duration += group_duration;
                    tile_max_group_duration = fmax(tile_max_group_duration, group_duration);
                }
            }

            const double tile_elapsed = get_time() - tile_start;
            total_parallel_render_time += tile_max_group_duration;
            /* The groups were measured back-to-back, so the time NOT spent in
             * pixel work is the elapsed time minus ALL group durations.
             * Subtracting only the slowest group would count the serial work
             * of the other groups as overhead whenever a tile holds more than
             * one PE group. */
            total_tile_overhead += tile_elapsed - tile_sum_group_duration;
        }
    }

    TIME_END(render, "Rendering Completed");

    printf("[RENDER-FINE] Parallel-simulated total render time: %f s\n", total_parallel_render_time);
    printf("[RENDER-FINE] tile overhead: %f s\n", total_tile_overhead);
    fflush(stdout);
}

// void renderCPU_original(
//     const dim3 grid,
//     const uint2 *ranges,
//     const uint32_t *point_list,
//     int W, int H,
//     const float2 *points_xy_image,
//     const float *features,
//     const float4 *conic_opacity,
//     float *final_T,
//     uint32_t *n_contrib,
//     const float *bg_color,
//     float *out_color)
// {
//     uint32_t vertical_blocks = (H + BLOCK_Y - 1) / BLOCK_Y;
// 	uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
    
//     for (uint32_t tile_y = 0; tile_y < vertical_blocks; tile_y++)
//     {
//         for (uint32_t tile_x = 0; tile_x < horizontal_blocks; tile_x++)
//         {
//             uint2 pix_min = {tile_x * BLOCK_X, tile_y * BLOCK_Y};
//             uint2 pix_max = {static_cast<uint32_t>(min_int(pix_min.x + BLOCK_X, W)), 
//                 static_cast<uint32_t>(min_int(pix_min.y + BLOCK_Y, H))};
//             uint2 range = ranges[tile_y * horizontal_blocks + tile_x];
//             int toDo = range.y - range.x;

            
//             for (uint32_t pix_y = pix_min.y; pix_y < pix_max.y; pix_y++)
//             {
//                 for (uint32_t pix_x = pix_min.x; pix_x < pix_max.x; pix_x++)
//                 {
//                     uint32_t pix_id = W * pix_y + pix_x;
//                     float2 pixf = {(float)pix_x, (float)pix_y};

//                     float T = 1.0f;
//                     uint32_t contributor = 0;
//                     uint32_t last_contributor = 0;
//                     float C[NUM_CHANNELS] = {0.0f};

                    
//                     for (int idx = range.x; idx < range.y; idx++)
//                     {
//                         contributor++;
//                         int coll_id = point_list[idx];

//                         float2 xy = points_xy_image[coll_id];
//                         float2 d = {xy.x - pixf.x, xy.y - pixf.y};
//                         float4 con_o = conic_opacity[coll_id];
//                         float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y;
//                         if (power > 0.0f)
//                             continue;

//                         float alpha = fmin(0.99f, con_o.w * exp(power));
//                         if (alpha < 1.0f / 255.0f)
//                             continue;
//                         float test_T = T * (1 - alpha);
//                         if (test_T < 0.0001f)
//                             break;

//                         for (int ch = 0; ch < NUM_CHANNELS; ch++)
//                             C[ch] += features[coll_id * NUM_CHANNELS + ch] * alpha * T;

//                         T = test_T;
//                         last_contributor = contributor;
//                     }

//                     final_T[pix_id] = T;
//                     n_contrib[pix_id] = last_contributor;
//                     for (int ch = 0; ch < NUM_CHANNELS; ch++)
//                         out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch];
//                 }
//             }
//         }
//     }
// }



// ---------------------------------------------------------------------
// Top-level timing for each image
// ---------------------------------------------------------------------
void cpu_rasterizer_render(
    const dim3 grid,
    const uint2 *ranges,
    const uint32_t *point_list,
    int W, int H,
    const float2 *means2D,
    const float *colors,
    const float4 *conic_opacity,
    float *final_T,
    uint32_t *n_contrib,
    const float *bg_color,
    float *out_color)
{
    TIME_START(render);

    renderCPU_original(
        grid,
        ranges,
        point_list,
        W, H,
        means2D,
        colors,
        conic_opacity,
        final_T,
        n_contrib,
        bg_color,
        out_color
    );

    TIME_END(render, "Rendering time (simulated 16x16 parallel)");

    // Print partial sums for exclusive times so far
    printf(
        "\n[RENDER-FINE] Simulated Parallel Execution => tileOverhead=%.4f s, pixelOverhead=%.4f s, alpha=%.4f s\n",
        g_tile_excl, g_pixel_excl, g_alpha
    );
    fflush(stdout);
}




// ---------------------------------------------------------------------
// Preprocessing with top-level timing
// ---------------------------------------------------------------------
void cpu_rasterizer_preprocess(
    int P, int D, int M,
    const float *means3D,
    const float3 *scales,
    const float scale_modifier,
    const float4 *rotations,
    const float *opacities,
    const float *shs,
    int *clamped,
    const float *cov3D_precomp,
    const float *colors_precomp,
    const float *viewmatrix,
    const float *projmatrix,
    const float3 *cam_pos,
    const int W, int H,
    const float tan_fovx, float tan_fovy,
    const float focal_x, const float focal_y,
    int *radii,
    float2 *means2D,
    float *depths,
    float *cov3Ds,
    float *rgb,
    float4 *conic_opacity,
    const dim3 grid,
    uint32_t *tiles_touched,
    int prefiltered)
{
    /*
     * Wall-clock wrapper around preprocessCPU: every argument is forwarded
     * unchanged and in the same order, and the elapsed time is printed as
     * "[TIMING] Preprocessing time: <seconds>" by TIME_END.
     */
    TIME_START(prep);

    preprocessCPU(P, D, M,
                  means3D, scales, scale_modifier, rotations, opacities, shs,
                  clamped, cov3D_precomp, colors_precomp,
                  viewmatrix, projmatrix, cam_pos,
                  W, H, tan_fovx, tan_fovy, focal_x, focal_y,
                  radii, means2D, depths, cov3Ds, rgb, conic_opacity,
                  grid, tiles_touched, prefiltered);

    TIME_END(prep, "Preprocessing time");
}